from layer_compiler.enum_def import Type, Op, Buf, Sched, Mecha
from layer_compiler.layer import Layer, Container
from layer_compiler.compiler import NN

# Module-level workload containers: each Container accumulates the ordered
# layer list for one benchmark network, filled in by the *_init functions below.

# MLP sample workload
container_mlp_sample = Container()

# CNN workloads
container_cnn_alex = Container()
container_cnn_google = Container()
container_cnn_vgg = Container()
container_cnn_mobile = Container()

# RNN workloads (ASR / machine translation / sentiment analysis)
container_rnn_asr = Container()
container_rnn_mt = Container()
container_rnn_sa = Container()

# Newly added CNN workloads
container_cnn_tinyyolo  = Container()
container_cnn_resnet50  = Container()
container_cnn_efficient = Container()

# Number of compute cores assumed by the compiler (previous value was 228).
corenum = 271 # 228

# Per-network isolated-execution figures stored into Container.isolated[1].
# NOTE(review): presumably measured cycle counts for a single instance running
# alone — confirm units against the compiler's scheduler.
tinyyolo_iso = 2390785
resnet50_iso = 5097888
efficient_iso = 10092968
mobile_iso = 1867013

def cnn_tinyyolo_init(batch):
    """Populate container_cnn_tinyyolo with the Tiny-YOLO layer sequence.

    Five conv(3x3)/max-pool(2x2) stages take 416x416x3 down to 13x13x256
    while doubling channels (16..256), followed by four 3x3 convs at 13x13.
    """
    layers = []

    # Stage ladder: conv doubles channels, pool halves spatial resolution.
    size, channels = 416, 3
    for out_ch in (16, 32, 64, 128, 256):
        layers.append(Layer(Type.CONV, batch=batch, in_dim=(size, size, channels),
                            kernel_dim=(3, 3), kernel_num=out_ch, stride=1, padding=1,
                            previous_input=bool(layers)))  # first conv reads fresh input
        layers.append(Layer(Type.POOL, batch=batch, in_dim=(size, size, out_ch),
                            window_dim=(2, 2), stride=2, previous_input=True))
        size //= 2
        channels = out_ch

    # Trailing 3x3 convs at 13x13 (512 -> 1024 -> 1024 -> 125 detection channels).
    for out_ch in (512, 1024, 1024, 125):
        layers.append(Layer(Type.CONV, batch=batch, in_dim=(size, size, channels),
                            kernel_dim=(3, 3), kernel_num=out_ch, stride=1, padding=1,
                            previous_input=True))
        channels = out_ch

    for layer in layers:
        container_cnn_tinyyolo.push_layer(layer)

    container_cnn_tinyyolo.isolated[1] = tinyyolo_iso
    container_cnn_tinyyolo.net_name = "TinyYoloNet"
    print("[I] cnn_tinyyolo_init")

def cnn_resnet50_init(batch):
    """Populate container_cnn_resnet50 with a ResNet-50 bottleneck stack.

    Bottleneck block = 1x1 reduce -> 3x3 -> 1x1 expand; stage repeat counts
    follow the paper: 3, 4, 6, 3. The first block of stages 2-4 downsamples
    via stride-2 on its 1x1 reduce conv.

    NOTE(review): resnet50_iso was presumably measured against the earlier
    version of this function (which repeated stage 4 six times) — re-measure
    if exact isolated cycle counts matter.
    """
    # stage 0: 7x7/2 conv + 3x3/2 max-pool (224 -> 56)
    conv0_1 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 3), kernel_dim=(7, 7), kernel_num=64, stride=2, padding=3, previous_input=False)
    pool0_1 = Layer(Type.POOL, batch=batch, in_dim=(112, 112, 64), window_dim=(3, 3), stride=2, previous_input=True)
    container_cnn_resnet50.push_layer(conv0_1)
    container_cnn_resnet50.push_layer(pool0_1)

    # stage 1: 3 bottleneck blocks at 56x56 (64 -> 256 channels)
    for i in range(3):
        conv1_1 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 64 if i == 0 else 256), kernel_dim=(1, 1), kernel_num=64, stride=1, padding=0, previous_input=True)
        conv1_2 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 64), kernel_dim=(3, 3), kernel_num=64, stride=1, padding=1, previous_input=True)
        conv1_3 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 64), kernel_dim=(1, 1), kernel_num=256, stride=1, padding=0, previous_input=True)
        container_cnn_resnet50.push_layer(conv1_1)
        container_cnn_resnet50.push_layer(conv1_2)
        container_cnn_resnet50.push_layer(conv1_3)

    # stage 2: 4 bottleneck blocks at 28x28 (first block downsamples 56 -> 28)
    for i in range(4):
        if i == 0:
            conv2_1 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 256), kernel_dim=(1, 1), kernel_num=128, stride=2, padding=0, previous_input=True)
        else:
            conv2_1 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 512), kernel_dim=(1, 1), kernel_num=128, stride=1, padding=0, previous_input=True)
        conv2_2 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 128), kernel_dim=(3, 3), kernel_num=128, stride=1, padding=1, previous_input=True)
        conv2_3 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 128), kernel_dim=(1, 1), kernel_num=512, stride=1, padding=0, previous_input=True)
        container_cnn_resnet50.push_layer(conv2_1)
        container_cnn_resnet50.push_layer(conv2_2)
        container_cnn_resnet50.push_layer(conv2_3)

    # stage 3: 6 bottleneck blocks at 14x14 (first block downsamples 28 -> 14)
    for i in range(6):
        if i == 0:
            conv3_1 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 512), kernel_dim=(1, 1), kernel_num=256, stride=2, padding=0, previous_input=True)
        else:
            conv3_1 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 1024), kernel_dim=(1, 1), kernel_num=256, stride=1, padding=0, previous_input=True)
        conv3_2 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 256), kernel_dim=(3, 3), kernel_num=256, stride=1, padding=1, previous_input=True)
        conv3_3 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 256), kernel_dim=(1, 1), kernel_num=1024, stride=1, padding=0, previous_input=True)
        container_cnn_resnet50.push_layer(conv3_1)
        container_cnn_resnet50.push_layer(conv3_2)
        container_cnn_resnet50.push_layer(conv3_3)

    # stage 4: 3 bottleneck blocks at 7x7 (first block downsamples 14 -> 7).
    # BUGFIX: was range(6), a copy-paste of stage 3's count — ResNet-50 uses 3 here.
    for i in range(3):
        if i == 0:
            conv4_1 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 1024), kernel_dim=(1, 1), kernel_num=512, stride=2, padding=0, previous_input=True)
        else:
            conv4_1 = Layer(Type.CONV, batch=batch, in_dim=(7, 7, 2048), kernel_dim=(1, 1), kernel_num=512, stride=1, padding=0, previous_input=True)
        conv4_2 = Layer(Type.CONV, batch=batch, in_dim=(7, 7, 512), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
        conv4_3 = Layer(Type.CONV, batch=batch, in_dim=(7, 7, 512), kernel_dim=(1, 1), kernel_num=2048, stride=1, padding=0, previous_input=True)
        container_cnn_resnet50.push_layer(conv4_1)
        container_cnn_resnet50.push_layer(conv4_2)
        container_cnn_resnet50.push_layer(conv4_3)

    # Classifier head. NOTE(review): no global average pool is modeled, so the
    # FC consumes the full 7x7x2048 feature map — confirm this is intended.
    fc = Layer(Type.FC, batch=batch, in_dim=7 * 7 * 2048, out_dim=1000, previous_input=True)
    container_cnn_resnet50.push_layer(fc)

    container_cnn_resnet50.isolated[1] = resnet50_iso
    container_cnn_resnet50.net_name = "ResNet50"
    print("[I] cnn_resnet50_init")

def cnn_efficient_init(batch):
    """Populate container_cnn_efficient with an EfficientNet-style MBConv stack.

    Each MBConv block is modeled as three layers pushed in order:
      1x1 expansion conv -> depthwise conv -> 1x1 projection conv.

    Modeling notes preserved from the original:
    - The downsampling stride is applied on the 1x1 expansion conv (the
      reference network puts it on the depthwise conv).
    - Expansion and projection convs use previous_input=False — presumably to
      model the residual/skip connections; TODO confirm against the compiler.
    """
    # Stem: 3x3/2 conv, 224 -> 112
    conv3x3 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 3), kernel_dim=(3, 3), kernel_num=32, stride=2, padding=1, previous_input=False)
    container_cnn_efficient.push_layer(conv3x3)

    def push_mbconv(in_hw, out_hw, in_ch, expand, dw_kernel, out_ch, stride):
        # Push one MBConv block (expand -> depthwise -> project).
        hidden = in_ch * expand
        dw_pad = dw_kernel // 2  # "same" padding for the depthwise conv
        up = Layer(Type.CONV, batch=batch, in_dim=(in_hw, in_hw, in_ch), kernel_dim=(1, 1), kernel_num=hidden, stride=stride, padding=0, previous_input=False)
        dw = Layer(Type.DEPTH, batch=batch, in_dim=(out_hw, out_hw, hidden), kernel_dim=(dw_kernel, dw_kernel), stride=1, padding=dw_pad, previous_input=True)
        dn = Layer(Type.CONV, batch=batch, in_dim=(out_hw, out_hw, hidden), kernel_dim=(1, 1), kernel_num=out_ch, stride=1, padding=0, previous_input=False)
        container_cnn_efficient.push_layer(up)
        container_cnn_efficient.push_layer(dw)
        container_cnn_efficient.push_layer(dn)

    # MBConv configuration, one row per block:
    # (in_hw, out_hw, in_ch, expand, dw_kernel, out_ch, stride)
    mbconv_cfg = [
        (112, 112, 32, 1, 3, 16, 1),    # MBConv1 3x3: 112x112x32 -> 112x112x16
        (112, 56, 16, 6, 3, 24, 2),     # MBConv6 3x3: 112x112x16 -> 56x56x24
        (56, 56, 24, 6, 3, 24, 1),      # MBConv6 3x3: 56x56x24 -> 56x56x24
        (56, 28, 24, 6, 5, 40, 2),      # MBConv6 5x5: 56x56x24 -> 28x28x40
        (28, 28, 40, 6, 5, 40, 1),      # MBConv6 5x5: 28x28x40 -> 28x28x40
        (28, 28, 40, 6, 3, 80, 1),      # MBConv6 3x3: 28x28x40 -> 28x28x80
        (28, 28, 80, 6, 3, 80, 1),      # MBConv6 3x3: 28x28x80 -> 28x28x80
        (28, 28, 80, 6, 3, 80, 1),      # MBConv6 3x3: 28x28x80 -> 28x28x80
        (28, 14, 80, 6, 5, 112, 2),     # MBConv6 5x5: 28x28x80 -> 14x14x112
        (14, 14, 112, 6, 5, 112, 1),    # MBConv6 5x5: 14x14x112 -> 14x14x112
        (14, 14, 112, 6, 5, 112, 1),    # MBConv6 5x5: 14x14x112 -> 14x14x112
        (14, 7, 112, 6, 5, 192, 2),     # MBConv6 5x5: 14x14x112 -> 7x7x192
        (7, 7, 192, 6, 5, 192, 1),      # MBConv6 5x5: 7x7x192 -> 7x7x192
        (7, 7, 192, 6, 5, 192, 1),      # MBConv6 5x5: 7x7x192 -> 7x7x192
        (7, 7, 192, 6, 5, 192, 1),      # MBConv6 5x5: 7x7x192 -> 7x7x192
        (7, 7, 192, 6, 5, 192, 1),      # MBConv6 5x5: 7x7x192 -> 7x7x192
        (7, 7, 192, 6, 3, 320, 1),      # MBConv6 3x3: 7x7x192 -> 7x7x320
    ]
    for in_hw, out_hw, in_ch, expand, dw_kernel, out_ch, stride in mbconv_cfg:
        push_mbconv(in_hw, out_hw, in_ch, expand, dw_kernel, out_ch, stride)

    # Classifier head (no global average pool modeled; FC consumes 7x7x320).
    fc = Layer(Type.FC, batch=batch, in_dim=7 * 7 * 320, out_dim=1000, previous_input=True)
    container_cnn_efficient.push_layer(fc)

    container_cnn_efficient.net_name = "EfficientNet"
    container_cnn_efficient.isolated[1] = efficient_iso
    print("[I] cnn_efficient_init")

def all_init(batch, length):
    """Initialize every benchmark workload container.

    batch  -- batch size forwarded to every network init.
    length -- sequence length forwarded to the RNN inits.
    """
    # Execution order preserved: MLP, ASR, the CNNs, then the remaining RNNs.
    steps = (
        lambda: four_mlp_init(batch),
        lambda: rnn_asr_init(batch, length),
        lambda: cnn_alex_init(batch),
        lambda: cnn_vgg_init(batch),
        lambda: cnn_google_init(batch),
        lambda: cnn_mobile_init(batch),
        lambda: cnn_tinyyolo_init(batch),
        lambda: cnn_resnet50_init(batch),
        lambda: cnn_efficient_init(batch),
        lambda: rnn_mt_init(batch, length),
        lambda: rnn_sa_init(batch, length),
    )
    for step in steps:
        step()


# 4-layer Sample 
def four_mlp_init(batch):
    """Push a 4-layer fully-connected sample net (100-400-400-400-10)."""
    widths = (100, 400, 400, 400, 10)
    # Consecutive width pairs define the FC layers; only the first layer
    # reads fresh input, the rest chain off the previous layer's output.
    for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        container_mlp_sample.push_layer(
            Layer(Type.FC, batch=batch, in_dim=fan_in, out_dim=fan_out,
                  previous_input=(idx > 0)))

# Alexnet
def cnn_alex_init(batch):
    """Populate container_cnn_alex with the AlexNet layer sequence.

    5 convs with interleaved 3x3 overlapping max-pools, then two FC layers.
    """
    conv1 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 3), kernel_dim=(11, 11), kernel_num=96, stride=4, padding=0)
    # BUGFIX: pool1 stride was 1, but conv2's declared in_dim (27, 27, 96)
    # requires 55 -> 27, i.e. AlexNet's 3x3 stride-2 overlapping pool.
    pool1 = Layer(Type.POOL, batch=batch, in_dim=(55, 55, 96), window_dim=(3, 3), stride=2, previous_input=True)
    conv2 = Layer(Type.CONV, batch=batch, in_dim=(27, 27, 96), kernel_dim=(5, 5), kernel_num=256, stride=1, padding=2, previous_input=True)
    pool2 = Layer(Type.POOL, batch=batch, in_dim=(27, 27, 256), window_dim=(3, 3), stride=2, previous_input=True)
    conv3 = Layer(Type.CONV, batch=batch, in_dim=(13, 13, 256), kernel_dim=(3, 3), kernel_num=384, stride=1, padding=1, previous_input=True)
    conv4 = Layer(Type.CONV, batch=batch, in_dim=(13, 13, 384), kernel_dim=(3, 3), kernel_num=384, stride=1, padding=1, previous_input=True)
    conv5 = Layer(Type.CONV, batch=batch, in_dim=(13, 13, 384), kernel_dim=(3, 3), kernel_num=256, stride=1, padding=1, previous_input=True)
    pool5 = Layer(Type.POOL, batch=batch, in_dim=(13, 13, 256), window_dim=(3, 3), stride=2, previous_input=True)
    # 6x6x256 = 9216 flattened features into the FC head.
    fc6 = Layer(Type.FC, batch=batch, in_dim=9216, out_dim=4096, previous_input=True)
    fc7 = Layer(Type.FC, batch=batch, in_dim=4096, out_dim=1000, previous_input=True)

    container_cnn_alex.push_layer(conv1)
    container_cnn_alex.push_layer(pool1)
    container_cnn_alex.push_layer(conv2)
    container_cnn_alex.push_layer(pool2)
    container_cnn_alex.push_layer(conv3)
    container_cnn_alex.push_layer(conv4)
    container_cnn_alex.push_layer(conv5)
    container_cnn_alex.push_layer(pool5)
    container_cnn_alex.push_layer(fc6)
    container_cnn_alex.push_layer(fc7)

    container_cnn_alex.net_name = "AlexNet"
    # Isolated-execution figures for batch keys 1 and 4.
    container_cnn_alex.isolated[1] = 1697722
    container_cnn_alex.isolated[4] = 6473754


# VGG16
# [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], avgpool
# N -> next channel num, kernel 3x3, stride=1, pad=1
def cnn_vgg_init(batch):
    """Populate container_cnn_vgg with the VGG16 layer sequence.

    Five conv stages (64, 128, 256, 512, 512 channels) each ending in a
    2x2/2 max-pool, followed by a three-layer FC head.
    """
    # stage 1: 224x224, 64 channels
    conv1 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 3), kernel_dim=(3, 3), kernel_num=64, stride=1, padding=1)
    conv2 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 64), kernel_dim=(3, 3), kernel_num=64, stride=1, padding=1, previous_input=True)
    pool2 = Layer(Type.POOL, batch=batch, in_dim=(224, 224, 64), window_dim=(2, 2), stride=2, previous_input=True)

    container_cnn_vgg.push_layer(conv1)
    container_cnn_vgg.push_layer(conv2)
    container_cnn_vgg.push_layer(pool2)

    # stage 2: 112x112, 128 channels
    conv3 = Layer(Type.CONV, batch=batch, in_dim=(112, 112, 64), kernel_dim=(3, 3), kernel_num=128, stride=1, padding=1, previous_input=True)
    conv4 = Layer(Type.CONV, batch=batch, in_dim=(112, 112, 128), kernel_dim=(3, 3), kernel_num=128, stride=1, padding=1, previous_input=True)
    pool4 = Layer(Type.POOL, batch=batch, in_dim=(112, 112, 128), window_dim=(2, 2), stride=2, previous_input=True)

    container_cnn_vgg.push_layer(conv3)
    container_cnn_vgg.push_layer(conv4)
    container_cnn_vgg.push_layer(pool4)

    # stage 3: 56x56, 256 channels
    conv5 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 128), kernel_dim=(3, 3), kernel_num=256, stride=1, padding=1, previous_input=True)
    conv6 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 256), kernel_dim=(3, 3), kernel_num=256, stride=1, padding=1, previous_input=True)
    conv7 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 256), kernel_dim=(3, 3), kernel_num=256, stride=1, padding=1, previous_input=True)
    pool7 = Layer(Type.POOL, batch=batch, in_dim=(56, 56, 256), window_dim=(2, 2), stride=2, previous_input=True)

    container_cnn_vgg.push_layer(conv5)
    container_cnn_vgg.push_layer(conv6)
    container_cnn_vgg.push_layer(conv7)
    container_cnn_vgg.push_layer(pool7)

    # stage 4: 28x28, 512 channels
    conv8 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 256), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
    conv9 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 512), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
    conv10 = Layer(Type.CONV, batch=batch, in_dim=(28, 28, 512), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
    pool10 = Layer(Type.POOL, batch=batch, in_dim=(28, 28, 512), window_dim=(2, 2), stride=2, previous_input=True)

    container_cnn_vgg.push_layer(conv8)
    container_cnn_vgg.push_layer(conv9)
    container_cnn_vgg.push_layer(conv10)
    container_cnn_vgg.push_layer(pool10)

    # stage 5: 14x14, 512 channels.
    # NOTE(review): pool14 (7x7 -> ~3x3) is not part of standard VGG16 (which
    # uses an average pool keeping 7x7); presumably it approximates the
    # avgpool stage — confirm intent.
    conv11 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 512), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
    conv12 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 512), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
    conv13 = Layer(Type.CONV, batch=batch, in_dim=(14, 14, 512), kernel_dim=(3, 3), kernel_num=512, stride=1, padding=1, previous_input=True)
    pool13 = Layer(Type.POOL, batch=batch, in_dim=(14, 14, 512), window_dim=(2, 2), stride=2, previous_input=True)
    pool14 = Layer(Type.POOL, batch=batch, in_dim=(7, 7, 512), window_dim=(2, 2), stride=2, previous_input=True)

    container_cnn_vgg.push_layer(conv11)
    container_cnn_vgg.push_layer(conv12)
    container_cnn_vgg.push_layer(conv13)
    container_cnn_vgg.push_layer(pool13)
    container_cnn_vgg.push_layer(pool14)

    # FC head.
    # NOTE(review): fc1's in_dim=4096 does not match pool14's output
    # (3*3*512 = 4608) nor standard VGG16 (7*7*512 = 25088) — verify against
    # the compiler's dimension handling.
    fc1 = Layer(Type.FC, batch=batch, in_dim=4096, out_dim=4096, previous_input=True)
    fc2 = Layer(Type.FC, batch=batch, in_dim=4096, out_dim=4096, previous_input=True)
    fc3 = Layer(Type.FC, batch=batch, in_dim=4096, out_dim=1000, previous_input=True)

    container_cnn_vgg.push_layer(fc1)
    container_cnn_vgg.push_layer(fc2)
    container_cnn_vgg.push_layer(fc3)

    container_cnn_vgg.net_name = "VGG16"
    # Isolated-execution figures for batch keys 1 and 4.
    container_cnn_vgg.isolated[1] = 6150050
    container_cnn_vgg.isolated[4] = 77632498
 
# GoogLeNet
# https://arxiv.org/pdf/1409.4842.pdf
def cnn_google_init(batch):
    """Populate container_cnn_google with the GoogLeNet layer sequence.

    Stem convs/pools followed by nine inception modules (see push_inception)
    and a final FC classifier. https://arxiv.org/pdf/1409.4842.pdf
    """
    # NOTE(review): the stem/stage pools here omit previous_input=True unlike
    # the pools in the other networks — confirm that is intended.
    conv1 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 3), kernel_dim=(7, 7), kernel_num=64, stride=2, padding=3, previous_input=False)
    pool1 = Layer(Type.POOL, batch=batch, in_dim=(112, 112, 64), window_dim=(3, 3), stride=2)
    conv2 = Layer(Type.CONV, batch=batch, in_dim=(56, 56, 64), kernel_dim=(3, 3), kernel_num=192, stride=1, padding=1, previous_input=True)
    pool2 = Layer(Type.POOL, batch=batch, in_dim=(56, 56, 192), window_dim=(3, 3), stride=2)

    container_cnn_google.push_layer(conv1)
    container_cnn_google.push_layer(pool1)
    container_cnn_google.push_layer(conv2)
    container_cnn_google.push_layer(pool2)

    # inception 3a/3b at 28x28
    next_in = push_inception(batch, (28, 28, 192), (64, 96, 128, 16, 32, 32))
    next_in = push_inception(batch, next_in, (128, 128, 192, 32, 96, 64))

    pool3 = Layer(Type.POOL, batch=batch, in_dim=next_in, window_dim=(3, 3), stride=2)
    container_cnn_google.push_layer(pool3)

    # inception 4a-4e at 14x14
    next_in = push_inception(batch, (14, 14, 480), (192, 96, 208, 16, 48, 64))
    next_in = push_inception(batch, next_in, (160, 112, 224, 24, 64, 64))
    next_in = push_inception(batch, next_in, (128, 128, 256, 24, 64, 64))
    next_in = push_inception(batch, next_in, (112, 144, 288, 32, 64, 64))
    next_in = push_inception(batch, next_in, (256, 160, 320, 32, 128, 128))

    pool4 = Layer(Type.POOL, batch=batch, in_dim=next_in, window_dim=(3, 3), stride=2)
    container_cnn_google.push_layer(pool4)

    # inception 5a/5b at 7x7
    next_in = push_inception(batch, (7, 7, 832), (256, 160, 320, 32, 128, 128))
    next_in = push_inception(batch, next_in, (384, 192, 384, 48, 128, 128))

    pool5 = Layer(Type.POOL, batch=batch, in_dim=next_in, window_dim=(7, 7), stride=1)
    container_cnn_google.push_layer(pool5)

    fc = Layer(Type.FC, batch=batch, in_dim=1024, out_dim=1000, previous_input=True)
    # BUGFIX: the final FC was pushed into container_cnn_vgg (copy-paste error),
    # leaving GoogLeNet without its classifier and polluting the VGG workload.
    container_cnn_google.push_layer(fc)

    container_cnn_google.net_name = "GoogLeNet"
    # Isolated-execution figures for batch keys 1 and 4.
    container_cnn_google.isolated[1] = 813951
    container_cnn_google.isolated[4] = 11873594


def push_inception(batch: int, in_dim: tuple, channel_dim: tuple):
    """Append one GoogLeNet inception module to container_cnn_google.

    channel_dim unpacks as (1x1, 3x3-reduce, 3x3, 5x5-reduce, 5x5, pool-proj).
    Returns the module's output dimension (same spatial size, concatenated
    channels of the four branches).
    """
    h, w = in_dim[0], in_dim[1]
    n1x1, n3x3red, n3x3, n5x5red, n5x5, npool = channel_dim

    branches = [
        # branch 1: plain 1x1 conv
        [Layer(Type.CONV, batch=batch, in_dim=in_dim, kernel_dim=(1, 1), kernel_num=n1x1, stride=1, padding=0, previous_input=True)],
        # branch 2: 1x1 reduce -> 3x3 conv
        [Layer(Type.CONV, batch=batch, in_dim=in_dim, kernel_dim=(1, 1), kernel_num=n3x3red, stride=1, padding=0, previous_input=True),
         Layer(Type.CONV, batch=batch, in_dim=(h, w, n3x3red), kernel_dim=(3, 3), kernel_num=n3x3, stride=1, padding=1, previous_input=True)],
        # branch 3: 1x1 reduce -> 5x5 conv
        [Layer(Type.CONV, batch=batch, in_dim=in_dim, kernel_dim=(1, 1), kernel_num=n5x5red, stride=1, padding=0, previous_input=True),
         Layer(Type.CONV, batch=batch, in_dim=(h, w, n5x5red), kernel_dim=(5, 5), kernel_num=n5x5, stride=1, padding=2, previous_input=True)],
        # branch 4: 3x3 pool -> 1x1 projection
        [Layer(Type.POOL, batch=batch, in_dim=in_dim, window_dim=(3, 3), stride=1, previous_input=True),
         Layer(Type.CONV, batch=batch, in_dim=in_dim, kernel_dim=(1, 1), kernel_num=npool, stride=1, padding=0, previous_input=True)],
    ]
    for branch in branches:
        for layer in branch:
            container_cnn_google.push_layer(layer)

    # Output channels = concatenation of the non-reduce branch outputs.
    return (h, w, n1x1 + n3x3 + n5x5 + npool)

# MobileNet
def cnn_mobile_init(batch):
    """Populate container_cnn_mobile with the MobileNet v1 layer sequence.

    A 3x3/2 stem conv followed by depthwise-separable blocks (3x3 depthwise
    then 1x1 pointwise), an average pool and an FC classifier.
    """
    conv1 = Layer(Type.CONV, batch=batch, in_dim=(224, 224, 3), kernel_dim=(3, 3), kernel_num=32, stride=2, padding=1, previous_input=False)
    container_cnn_mobile.push_layer(conv1)

    def push_dw_separable(in_hw, in_ch, stride, out_hw, out_ch):
        # One depthwise-separable block: 3x3 depthwise (stride s) + 1x1 pointwise.
        dw = Layer(Type.DEPTH, batch=batch, in_dim=(in_hw, in_hw, in_ch), kernel_dim=(3, 3), stride=stride, padding=1, previous_input=True)
        # BUGFIX: pointwise convs used padding=1, which would grow the spatial
        # size (e.g. 112 -> 114) and contradict the declared in_dims; a 1x1
        # conv takes padding=0, as in every other 1x1 conv in this file.
        pw = Layer(Type.CONV, batch=batch, in_dim=(out_hw, out_hw, in_ch), kernel_dim=(1, 1), kernel_num=out_ch, stride=1, padding=0, previous_input=True)
        container_cnn_mobile.push_layer(dw)
        container_cnn_mobile.push_layer(pw)

    # Block table: (in_hw, in_ch, stride, out_hw, out_ch)
    blocks = [
        (112, 32, 1, 112, 64),
        (112, 64, 2, 56, 128),
        (56, 128, 1, 56, 128),
        (56, 128, 2, 28, 256),
        # NOTE(review): standard MobileNet has an extra 28x28x256 stride-1
        # block before this downsample — confirm the omission is intended.
        (28, 256, 2, 14, 512),
        (14, 512, 1, 14, 512),
        (14, 512, 1, 14, 512),
        (14, 512, 1, 14, 512),
        (14, 512, 1, 14, 512),
        (14, 512, 1, 14, 512),
        (14, 512, 2, 7, 1024),
        (7, 1024, 1, 7, 1024),
    ]
    for in_hw, in_ch, stride, out_hw, out_ch in blocks:
        push_dw_separable(in_hw, in_ch, stride, out_hw, out_ch)

    # Global 7x7 pool then FC classifier.
    pool = Layer(Type.POOL, batch=batch, in_dim=(7, 7, 1024), window_dim=(7, 7), stride=1, previous_input=True)
    fc = Layer(Type.FC, batch=batch, in_dim=1024, out_dim=1000, previous_input=True)
    container_cnn_mobile.push_layer(pool)
    container_cnn_mobile.push_layer(fc)

    container_cnn_mobile.net_name = "MobileNet"
    # Isolated-execution figures for batch keys 1 and 4.
    container_cnn_mobile.isolated[1] = mobile_iso
    container_cnn_mobile.isolated[4] = 25908758

# ASR: Listen, Attend and Spell
# https://arxiv.org/pdf/1508.01211.pdf
# config: https://github.com/kaituoxu/Listen-Attend-Spell/blob/master/egs/aishell/run.sh#L21
'''
Model Architecture
[Encoder]
1) pyramidal Bi-LSTM with 3 layers (each layer halves the time resolution of its input)
[Decoder]
1) 2-layer Bi-LSTM
2) Attention
3) 2-layer MLP
'''
def rnn_asr_init(batch, length):
    """Populate container_rnn_asr with a Listen-Attend-Spell ASR network.

    Builds, in order: a 3-layer pyramidal bidirectional LSTM encoder
    (time resolution Ti, Ti/2, Ti/4), a 2-layer bidirectional LSTM
    decoder, two GEMM layers modelling attention score/context, and a
    2-layer MLP head. Also sets the container's net_name and isolated
    cycle counts.

    Args:
        batch: batch size (N) for every layer.
        length: input sequence length Ti; decoder length To = Ti // 4.
    """
    # Assumption: output = int(input/4)
    # Embedding layer is omitted

    N = batch
    Ti = length
    To = int(length / 4)
    H = 256       # hidden, for single direction, i.e. 512 total
    D = 240       # input feature dimension
    D_EMBED = 512
    D_HIDDEN = 512

    def _push_bi_lstm(steps, in_dim, h_dim, previous_input):
        # One bidirectional LSTM layer unrolled over `steps` timesteps:
        # two directional cells per step, pushed forward-cell first.
        # Only the first timestep has no hidden state to consume.
        for i in range(steps):
            for _ in range(2):
                container_rnn_asr.push_layer(
                    Layer(Type.LSTM, batch=N, in_dim=in_dim, h_dim=h_dim,
                          no_hidden=(i == 0), previous_input=previous_input))

    # listener (encoder): pyramidal Bi-LSTM, time resolution halves per layer
    _push_bi_lstm(Ti, D, H, False)                      # first layer
    _push_bi_lstm(int(Ti / 2), H * 2, H * 2, True)      # second layer
    _push_bi_lstm(int(Ti / 4), H * 4, H * 4, True)      # third layer

    # speller (decoder): 2-layer bidirectional LSTM over To steps
    _push_bi_lstm(To, D_EMBED, D_HIDDEN, True)          # first layer
    _push_bi_lstm(To, D_EMBED, D_HIDDEN, True)          # second layer

    # attention: score = decoder states x encoder states (To x Ti),
    # then context = score x encoder states (To x D_HIDDEN)
    attention_score = Layer(Type.GEMM, batch=N, gemm_m=To, gemm_k=D_HIDDEN, gemm_n=Ti, previous_input=True)
    container_rnn_asr.push_layer(attention_score)
    attention_output = Layer(Type.GEMM, batch=N, gemm_m=To, gemm_k=Ti, gemm_n=D_HIDDEN, previous_input=True)
    container_rnn_asr.push_layer(attention_output)

    # 2-layer MLP head
    layer1 = Layer(Type.FC, batch=N, in_dim=D_HIDDEN * 2, out_dim=D_HIDDEN)
    layer2 = Layer(Type.FC, batch=N, in_dim=D_HIDDEN, out_dim=D)
    container_rnn_asr.push_layer(layer1)
    container_rnn_asr.push_layer(layer2)

    container_rnn_asr.net_name = "Automatic Speech Recognition"
    container_rnn_asr.isolated[1] = 13440293
    container_rnn_asr.isolated[4] = 13464404

# GNMTv2
def rnn_mt_init(batch, length):
    """Populate container_rnn_mt with a GNMT-style machine-translation net.

    Builds, in order: a 4-layer encoder (first layer bidirectional, the
    rest unidirectional), a 4-layer unidirectional decoder, per-step
    attention GEMM pairs (Bahdanau-style, approximate), and a per-step
    FC output layer. Also sets the container's net_name and isolated
    cycle counts.

    Args:
        batch: batch size (N) for every layer.
        length: sequence length; encoder Ti and decoder To are both equal
            to it (assumption: output length = input length).
    """
    # Assumption: output = input
    # Embedding layer is omitted

    N = batch
    Ti = length
    To = length
    H = 256       # hidden, for single direction, i.e. 512 total
    D = 240       # input feature dimension
    D_HIDDEN = 1024

    def _push_lstm(steps, in_dim, h_dim, previous_input, directions=1):
        # One LSTM layer unrolled over `steps` timesteps; `directions` cells
        # are pushed per step (2 = bidirectional). Only the first timestep
        # has no hidden state to consume.
        for i in range(steps):
            for _ in range(directions):
                container_rnn_mt.push_layer(
                    Layer(Type.LSTM, batch=N, in_dim=in_dim, h_dim=h_dim,
                          no_hidden=(i == 0), previous_input=previous_input))

    # encoder
    _push_lstm(Ti, D, H, False, directions=2)   # first layer (bidirectional)
    _push_lstm(Ti, H * 2, H, True)              # second layer
    _push_lstm(Ti, H, H, True)                  # third layer
    _push_lstm(Ti, H, H, True)                  # fourth layer

    # decoder
    _push_lstm(To, D, H, True)                  # first layer
    _push_lstm(To, H * 2, H, True)              # second layer
    _push_lstm(To, H * 2, H, True)              # third layer
    _push_lstm(To, H * 2, H, True)              # fourth layer

    # attention layer
    # Bahdanau attention
    # https://github.com/tensorflow/nmt#background-on-the-attention-mechanism
    # not accurate
    for _ in range(To):
        layer_t = Layer(Type.GEMM, batch=1, gemm_m=N, gemm_k=H, gemm_n=H)
        layer_h = Layer(Type.GEMM, batch=1, gemm_m=N, gemm_k=H, gemm_n=H)
        container_rnn_mt.push_layer(layer_t)
        container_rnn_mt.push_layer(layer_h)

    # fc layer (one projection per decoder step)
    for _ in range(To):
        fc = Layer(Type.FC, batch=N, in_dim=H, out_dim=H)
        container_rnn_mt.push_layer(fc)

    container_rnn_mt.net_name = 'Machine Translation'
    container_rnn_mt.isolated[1] = 9911366
    container_rnn_mt.isolated[4] = 9911366

# Sentiment Analysis
# https://github.com/mlperf/training/blob/master/sentiment_analysis/paddle/train.py#L48
def rnn_sa_init(batch, length):
    """Populate container_rnn_sa with a sentiment-analysis network.

    Builds a single unidirectional LSTM layer unrolled over `length`
    timesteps (each step pushed with no_hidden=True), followed by one
    FC layer, then sets the container's net_name and isolated cycle
    counts.

    Args:
        batch: batch size for every layer.
        length: number of LSTM timesteps to unroll.
    """
    # embedding layer is omitted

    hidden = 1024

    # one LSTM cell per timestep
    for _ in range(length):
        cell = Layer(Type.LSTM, batch=batch, in_dim=hidden, h_dim=hidden,
                     no_hidden=True, previous_input=False)
        container_rnn_sa.push_layer(cell)

    # final fully-connected projection
    container_rnn_sa.push_layer(
        Layer(Type.FC, batch=batch, in_dim=hidden, out_dim=hidden))

    container_rnn_sa.net_name = 'Sentiment Analysis'
    container_rnn_sa.isolated[1] = 4469462
    container_rnn_sa.isolated[4] = 4489981